import torch
from src.wasserstein import *
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from src.utils import *


def ood_gan_d_loss(logits_real, logits_fake, logits_ood, labels_real):
    """Compute the three raw terms of the discriminator objective.

    The terms are returned unweighted; the caller decides how to combine them.

    Args:
        logits_real: Discriminator logits for in-distribution inputs.
        logits_fake: Discriminator logits for generator outputs. Must be
            attached to the autograd graph.
        logits_ood: Discriminator logits for observed OoD samples. Must be
            attached to the autograd graph.
        labels_real: Class targets matching ``logits_real``.

    Returns:
        Tuple ``(cross_entropy, w_ood, w_fake)`` where the last two are the
        values ``batch_wasserstein`` yields for the OoD and the generated
        logits respectively.

    Raises:
        AssertionError: If ``logits_ood`` or ``logits_fake`` does not require
            gradients.
    """
    # Term 1: supervised classification loss on in-distribution data.
    cross_entropy = nn.CrossEntropyLoss()(logits_real, labels_real)

    # Terms 2 and 3: Wasserstein scores, OoD batch first, then generated
    # batch. Each tensor is checked right before it is scored so a detached
    # input fails loudly instead of silently contributing no gradient.
    scores = []
    for logits in (logits_ood, logits_fake):
        assert logits.requires_grad
        scores.append(batch_wasserstein(logits))

    ood_score, fake_score = scores
    return cross_entropy, ood_score, fake_score


def ood_gan_g_loss(logits_fake):
    """Return the Wasserstein score of the discriminator logits on G(z).

    Args:
        logits_fake: Discriminator logits for generator outputs. Must be
            attached to the autograd graph so the generator can be updated.

    Returns:
        Whatever ``batch_wasserstein`` yields for ``logits_fake``; sign and
        weighting are left to the caller.

    Raises:
        AssertionError: If ``logits_fake`` does not require gradients.
    """
    # A detached tensor here would make the generator update a no-op.
    assert logits_fake.requires_grad
    return batch_wasserstein(logits_fake)


class SEEOOD_TRAINER():
    """Alternating trainer for a discriminator ``D`` and a generator ``G``.

    Per in-distribution batch the trainer runs ``d_steps_ratio`` discriminator
    updates followed by ``g_steps_ratio`` generator updates:

    * ``D`` minimises ``CE - hp['ood'] * W_ood + hp['z'] * W_z``.
    * ``G`` minimises ``-hp['z'] * W_z``.

    Scalars are written to TensorBoard at every sub-step and a checkpoint
    holding both state dicts is saved every ``n_epochs_save`` epochs.

    Args:
        D: Discriminator / classifier module; called on image batches.
        G: Generator module; called on noise of shape
            ``(bsz_tri, noise_dim, 1, 1)``.
        noise_dim: Size of the second axis of the noise tensor fed to ``G``.
        num_classes: Stored on the instance; not read by this class.
        bsz_tri: Number of noise vectors drawn for each D or G sub-step.
        g_steps_ratio: Generator updates per batch (expected >= 1, because
            the console log line reads the last generator loss).
        d_steps_ratio: Discriminator updates per batch (expected >= 1, for
            the same reason).
        hp: Mapping with keys ``'ood'`` and ``'z'`` holding the loss weights.
        max_epochs: Number of passes over ``ind_loader``.
        ood_bsz: Upper bound on OoD samples drawn per discriminator sub-step.
        writer_name: Argument forwarded to ``SummaryWriter``.
        ckpt_name: Checkpoint file stem.
        ckpt_dir: Checkpoint directory prefix. It is concatenated directly
            with ``ckpt_name``, so it should end with a path separator.
        n_epochs_save: Save a checkpoint every this many epochs.
        n_steps_log: Print a status line every this many batches.
    """

    def __init__(self, D, G, noise_dim, num_classes,
                 bsz_tri, g_steps_ratio, d_steps_ratio,
                 hp, max_epochs, ood_bsz,
                 writer_name, ckpt_name, ckpt_dir,
                 n_epochs_save=2, n_steps_log=1):
        super().__init__()
        # Logger information
        self.writer_name = writer_name
        self.writer = SummaryWriter(writer_name)
        self.ckpt_name = ckpt_name
        self.ckpt_dir = ckpt_dir
        # Print statement config
        self.n_epochs_save = n_epochs_save
        self.n_steps_log = n_steps_log
        # Backbone models & info
        self.D = D
        self.G = G
        self.num_classes = num_classes
        self.noise_dim = noise_dim
        self.dloss = ood_gan_d_loss
        self.gloss = ood_gan_g_loss
        # Training config
        self.bsz_tri = bsz_tri
        self.d_steps_ratio = d_steps_ratio
        self.g_steps_ratio = g_steps_ratio
        self.max_epochs = max_epochs
        self.hp = hp
        self.ood_bsz = ood_bsz

    @staticmethod
    def _is_checkpoint_epoch(epoch, n_epochs_save):
        """Return True when 0-based ``epoch`` completes a save interval.

        The parentheses matter: ``epoch + 1 % n`` parses as
        ``epoch + (1 % n)``, which is almost never zero.
        """
        return (epoch + 1) % n_epochs_save == 0

    def _fake_logits(self):
        """Draw a fresh noise batch, run it through G, and score it with D."""
        seed = torch.rand(
            (self.bsz_tri, self.noise_dim, 1, 1), device=DEVICE)
        return self.D(self.G(seed))

    def train(self, ind_loader, ood_samples, D_solver, G_solver):
        """Run the full adversarial training schedule.

        Args:
            ind_loader: Iterable yielding ``(x, y)`` in-distribution batches.
            ood_samples: Tensor of observed OoD images, indexed along its
                first axis and sliced as a 4-D tensor.
            D_solver: Optimizer for ``D`` (``zero_grad`` / ``step``).
            G_solver: Optimizer for ``G`` (``zero_grad`` / ``step``).
        """
        # Print out OoD sample statistics
        print(f"OoD sample shape: {ood_samples.shape}")
        ood_samples = ood_samples.to(DEVICE)
        n_ood = len(ood_samples)
        # Never request more OoD samples than exist (sampling is without
        # replacement below).
        ood_bsz = min(n_ood, self.ood_bsz)

        # START TRAINING
        # iter_count counts batches across ALL epochs; it drives both the
        # console log cadence and the TensorBoard global step.
        iter_count = 0
        for epoch in range(self.max_epochs):
            for steps, (x, y) in enumerate(tqdm(ind_loader)):
                x, y = x.to(DEVICE), y.to(DEVICE)
                # ---------------------- #
                # DISCRIMINATOR TRAINING #
                # ---------------------- #
                for dstep in range(self.d_steps_ratio):
                    D_solver.zero_grad()

                    # InD Classification
                    logits_real = self.D(x)

                    # Adversarial Training
                    logits_fake = self._fake_logits()

                    # OoD Wasserstein Mapping. ood_samples already lives on
                    # DEVICE, so the indexed slice needs no further transfer.
                    ood_idx = np.random.choice(n_ood, ood_bsz, replace=False)
                    logits_ood = self.D(ood_samples[ood_idx, :, :, :])

                    # Overall Loss Function
                    ce, w_ood, w_fake = self.dloss(
                        logits_real, logits_fake, logits_ood, y)
                    d_total = (ce
                               - self.hp['ood'] * w_ood
                               + self.hp['z'] * w_fake)

                    # Write relevant statistics. The step is derived from the
                    # cross-epoch counter so later epochs do not log onto the
                    # same step range as earlier ones.
                    global_step_d = iter_count * self.d_steps_ratio + dstep
                    self.writer.add_scalars("Discriminator Loss/each", {
                        'CE': ce.detach(),
                        'W_ood': w_ood.detach(),
                        'W_z': w_fake.detach()
                    }, global_step_d)
                    self.writer.add_scalar(
                        "Discriminator Loss/total", d_total.detach(),
                        global_step_d)

                    # Gradient Update
                    d_total.backward()
                    D_solver.step()

                # ------------------ #
                # GENERATOR TRAINING #
                # ------------------ #
                for gstep in range(self.g_steps_ratio):
                    G_solver.zero_grad()

                    # OoD Adversarial Training
                    logits_fake = self._fake_logits()

                    # Use the loss bound on the instance, mirroring how the
                    # discriminator side goes through self.dloss.
                    w_z = self.gloss(logits_fake)
                    g_total = -self.hp['z'] * w_z

                    # Write relevant statistics (cross-epoch step, see above)
                    global_step_g = iter_count * self.g_steps_ratio + gstep
                    self.writer.add_scalars("Generator Loss/each", {
                        'W_z': w_z.detach()
                    }, global_step_g)
                    self.writer.add_scalar(
                        "Generator Loss/total", g_total.detach(),
                        global_step_g)

                    # Gradient Update
                    g_total.backward()
                    G_solver.step()

                # Print out statistics from the last D and G sub-steps
                if (iter_count % self.n_steps_log == 0):
                    print(
                        f"Step: {steps:<4} | D: {d_total.item(): .4f} | CE: {ce.item(): .4f} | W_OoD: {w_ood.item(): .4f} | W_z: {w_fake.item(): .4f} | G: {g_total.item(): .4f} | W_z: {w_z.item(): .4f}")
                iter_count += 1

            # Save checkpoint
            if self._is_checkpoint_epoch(epoch, self.n_epochs_save):
                ckpt_name = f"{self.ckpt_dir}{self.ckpt_name}_[{epoch+1}].pt"
                torch.save({
                    'D-state': self.D.state_dict(),
                    'G-state': self.G.state_dict()
                }, ckpt_name)
                print(f'New checkpoint created at the end of epoch {epoch+1}.')
